from dataclasses import dataclass

import numpy as np
from transformers import PreTrainedTokenizer


@dataclass(kw_only=True)
class TokenizedRewriting:
    prompt: np.ndarray
    target: np.ndarray


@dataclass(kw_only=True)
class TextRewriting:
    """Plain-text form of a rewriting example: a prompt and its target."""

    # Prompt text.
    prompt: str
    # Target text.
    target: str

    def tokenize(self, tokenizer: PreTrainedTokenizer):
        """Encode both texts and return them as a ``TokenizedRewriting``.

        Args:
            tokenizer: Tokenizer whose ``encode`` method is applied to
                ``prompt`` and then to ``target``. It is called with no
                extra arguments, so the tokenizer's own defaults apply.

        Returns:
            A ``TokenizedRewriting`` whose ``prompt`` and ``target`` are
            int64 NumPy arrays built from the respective ``encode`` results.
        """
        # Same conversion for both fields; prompt is encoded before target.
        prompt_ids, target_ids = (
            np.asarray(tokenizer.encode(text), dtype=np.int64)
            for text in (self.prompt, self.target)
        )
        return TokenizedRewriting(prompt=prompt_ids, target=target_ids)
